<!doctype html>
<html lang="en">

<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
  <title>PuckWorld</title>
  <meta name="description" content="">
  <meta name="author" content="">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">

  <!-- jquery and jqueryui -->
  <script src="external/jquery-2.1.3.min.js"></script>
  <link href="external/jquery-ui.min.css" rel="stylesheet">
  <script src="external/jquery-ui.min.js"></script>

  <!-- bootstrap -->
  <script src="external/bootstrap.min.js"></script>
  <link href="external/bootstrap.min.css" rel="stylesheet">

  <!-- d3js -->
  <script type="text/javascript" src="external/d3.min.js"></script>

  <!-- markdown -->
  <script type="text/javascript" src="external/marked.js"></script>
  <script type="text/javascript" src="external/highlight.pack.js"></script>
  <link rel="stylesheet" href="external/highlight_default.css">
  <script>
    // highlight.js hook: ask the library to run its syntax highlighting once the
    // page has loaded (behaviour per the highlight.js API; the library itself is
    // loaded from external/highlight.pack.js above).
    hljs.initHighlightingOnLoad();
  </script>

  <!-- mathjax: nvm now loaded dynamically
  <script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
  -->

  <!-- rljs -->
  <script type="text/javascript" src="lib/rl.js"></script>

  <!-- flotjs -->
  <script src="external/jquery.flot.min.js"></script>

  <!-- GA -->
  <script>
    // Google Analytics loader snippet (kept verbatim).
    // - Stores the name 'ga' in window.GoogleAnalyticsObject and defines window.ga
    //   as a stub that pushes its arguments onto the queue ga.q.
    // - Records the current timestamp in ga.l.
    // - Creates a script element with async set, pointing at analytics.js, and
    //   inserts it before the first script element of the document.
    (function (i, s, o, g, r, a, m) {
    i['GoogleAnalyticsObject'] = r; i[r] = i[r] || function () {
      (i[r].q = i[r].q || []).push(arguments)
    }, i[r].l = 1 * new Date(); a = s.createElement(o),
      m = s.getElementsByTagName(o)[0]; a.async = 1; a.src = g; m.parentNode.insertBefore(a, m)
    })(window, document, 'script', '//www.google-analytics.com/analytics.js', 'ga');
    // Queue the tracker creation for this property id and a page-view hit.
    ga('create', 'UA-3698471-24', 'auto');
    ga('send', 'pageview');
  </script>

  <style>
    /* Fixed-width page column, centred horizontally via auto side margins. */
    #wrap {
      width: 800px;
      margin-left: auto;
      margin-right: auto;
    }

    /* Section headings are centred. */
    h2 {
      text-align: center;
    }

    body {
      font-family: Arial, "Helvetica Neue", Helvetica, sans-serif;
    }

    /* Offsets the drawing container: 100px left margin + the 600px SVG created
       by initDraw() sits centred inside the 800px #wrap column. */
    #draw {
      margin-left: 100px;
    }

    /* Thin frame around every svg element (the PuckWorld arena). */
    svg {
      border: 1px solid black;
    }
  </style>

  <script type="application/javascript">
    // Pixel size of the drawing surface. World coordinates are fractions of the
    // arena and are multiplied by W / H to obtain SVG coordinates.
    var W = 600, H = 600;
    // Root group of the SVG scene and handles to its elements.
    // All of them are (re)assigned by initDraw().
    var svg = null;              // previously an implicit global; now declared
    var d3line = null;           // line with arrowhead, positioned by updateDraw()
    var d3agent = null;          // the puck (circle #puck)
    var d3target = null;         // green target (circle #target)
    var d3target2 = null;        // red target (circle #target2)
    var d3target2_radius = null; // translucent red disc drawn around the red target

    /**
     * (Re)builds the SVG scene inside the #draw container.
     *
     * Clears #draw, appends a W x H svg with a single root group, defines the
     * "arrowhead" marker, and creates the puck, both targets, the red target's
     * radius disc and the force line. The created d3 selections are stored in
     * the module-level handles above so that updateDraw() can move them.
     *
     * Reads the global `env` (for the puck radius), so `env` must exist before
     * this is called.
     */
    var initDraw = function () {
      var d3elt = d3.select('#draw');
      d3elt.html('');

      // Use the shared W/H constants rather than a second pair of literals so
      // the canvas size and the coordinate scaling can never drift apart.
      svg = d3elt.append('svg').attr('width', W).attr('height', H)
        .append('g').attr('transform', 'scale(1)');

      // define a marker for drawing arrowheads
      svg.append("defs").append("marker")
        .attr("id", "arrowhead")
        .attr("refX", 3)
        .attr("refY", 2)
        .attr("markerWidth", 3)
        .attr("markerHeight", 4)
        .attr("orient", "auto")
        .append("path")
        .attr("d", "M 0,0 V 4 L3,2 Z");

      // draw the puck; radius is env.rad scaled to pixels.
      // (Was `this.W`, which only worked because `this` happened to be `window`.)
      d3agent = svg.append('circle')
        .attr('cx', 100)
        .attr('cy', 100)
        .attr('r', env.rad * W)
        .attr('fill', '#FF0')
        .attr('stroke', '#000')
        .attr('id', 'puck');

      // draw the target
      d3target = svg.append('circle')
        .attr('cx', 200)
        .attr('cy', 200)
        .attr('r', 10)
        .attr('fill', '#0F0')
        .attr('stroke', '#000')
        .attr('id', 'target');

      // bad target
      d3target2 = svg.append('circle')
        .attr('cx', 300)
        .attr('cy', 300)
        .attr('r', 10)
        .attr('fill', '#F00')
        .attr('stroke', '#000')
        .attr('id', 'target2');

      // disc around the bad target; its radius is set from env.BADRAD in updateDraw()
      d3target2_radius = svg.append('circle')
        .attr('cx', 300)
        .attr('cy', 300)
        .attr('r', 10)
        .attr('fill', 'rgba(255,0,0,0.1)')
        .attr('stroke', '#000');

      // draw line indicating forces
      d3line = svg.append('line')
        .attr('x1', 0)
        .attr('y1', 0)
        .attr('x2', 0)
        .attr('y2', 0)
        .attr('stroke', 'black')
        .attr('stroke-width', '2')
        .attr("marker-end", "url(#arrowhead)");
    };

    /**
     * Redraws the scene from the current state of the global `env`.
     *
     * @param {number[]} a  Action just applied; a[0] / a[1] are the horizontal /
     *                      vertical components drawn as an arrow from the puck.
     * @param {*}        s  State passed by the caller; unused here (positions are
     *                      read directly from `env`). Kept for interface stability.
     * @param {number}   r  Reward for the step; used only to tint the puck.
     */
    var updateDraw = function (a, s, r) {
      // reflect puck world state on screen
      var ppx = env.ppx, ppy = env.ppy;
      var tx = env.tx, ty = env.ty;
      var tx2 = env.tx2, ty2 = env.ty2;

      d3agent.attr('cx', ppx * W).attr('cy', ppy * H);
      d3target.attr('cx', tx * W).attr('cy', ty * H);
      d3target2.attr('cx', tx2 * W).attr('cy', ty2 * H);
      d3target2_radius.attr('cx', tx2 * W).attr('cy', ty2 * H).attr('r', env.BADRAD * H);

      // Force arrow: starts at the puck, extends af pixels per unit of action,
      // and is hidden when the action is exactly zero.
      var af = 20;
      d3line
        .attr('visibility', a[0] === 0 && a[1] === 0 ? 'hidden' : 'visible')
        .attr('x1', ppx * W)
        .attr('y1', ppy * H)
        .attr('x2', ppx * W + af * a[0])
        .attr('y2', ppy * H + af * a[1]);

      // Colour the puck by reward: vv > 0 shades towards green, vv < 0 towards
      // red, vv === 0 is white. Channels are local variables (the original leaked
      // `g`/`b` as globals, overwrote the reward parameter, and left the channels
      // undefined or stale when vv was exactly 0).
      var vv = r + 0.5;
      var ms = 255.0;
      var red = 255, green = 255, blue = 255;
      if (vv > 0) {
        red = 255 - vv * ms;
        blue = 255 - vv * ms;
      } else if (vv < 0) {
        green = 255 + vv * ms;
        blue = 255 + vv * ms;
      }
      // Keep each channel inside the valid 0..255 range.
      var channel = function (v) { return Math.max(0, Math.min(255, Math.floor(v))); };
      var vcol = 'rgb(' + channel(red) + ',' + channel(green) + ',' + channel(blue) + ')';
      d3agent.attr('fill', vcol);
    };
    /////////////////////////////////////////////////
    /**
     * PuckWorld environment.
     *
     * A puck moves inside the unit square under damped dynamics. It is rewarded
     * for being close to a target (tx, ty) and penalised for being within BADRAD
     * of a second point (tx2, ty2) that slowly drifts as computed in dynamics().
     */
    var PuckWorld = function () {
      // Display labels for the state and action components, listed in the same
      // order as the entries of getState() / the action vector.
      this.stateFeatures = [
        "个体位置X", "个体位置Y",
        "个体速度X", "个体速度Y",
        "个体目标X", "个体目标Y",
        "惩罚位置X", "惩罚位置Y"
      ];
      this.actionFeatures = ["水平", "竖直"];
      this.reset();
    }
    PuckWorld.prototype = {
      /** Re-randomises the puck, both targets and the puck velocity; resets the clock. */
      reset: function () {
        // puck position
        this.ppx = Math.random();
        this.ppy = Math.random();
        // puck velocity, uniform in [-0.025, 0.025)
        this.pvx = Math.random() * 0.05 - 0.025;
        this.pvy = Math.random() * 0.05 - 0.025;
        // target the puck is rewarded for approaching
        this.tx = Math.random();
        this.ty = Math.random();
        // centre of the penalty zone
        this.tx2 = Math.random();
        this.ty2 = Math.random();
        // puck radius: used for drawing and for the wall bounce, not for the reward
        this.rad = 0.05;
        // step counter
        this.t = 0;

        // radius of the penalty zone around (tx2, ty2)
        this.BADRAD = 0.25;

        // immediate-reward slot; not written by any other method of this class
        this.reward1 = 0.0;
      },

      /** Number of entries in the state vector returned by getState(). */
      getNumOfStateFeatures: function () {
        return 8;
      },

      /** Number of components of the action vector consumed by dynamics(). */
      getMaxNumOfActionFeatures: function () {
        return 2;
      },

      /**
       * Current state as an 8-element array:
       * [puck position relative to the arena centre (x, y),
       *  puck velocity scaled by 10 (x, y),
       *  target position relative to the puck (x, y),
       *  penalty-zone centre relative to the puck (x, y)].
       */
      getState: function () {
        var px = this.ppx, py = this.ppy;
        return [
          px - 0.5,
          py - 0.5,
          this.pvx * 10,
          this.pvy * 10,
          this.tx - px,
          this.ty - py,
          this.tx2 - px,
          this.ty2 - py
        ];
      },

      /**
       * Observation handed to the agent. Currently the full state; this is the
       * hook to use if some information should be hidden from the agent.
       */
      releasedObs: function () {
        return this.getState();
      },

      /**
       * Advances the world by one step under action `a` and returns the reward.
       *
       * @param {number[]} a  a[0] / a[1]: horizontal / vertical acceleration command.
       * @returns {number} Reward for this step. The new state is not returned;
       *                   read it with getState().
       */
      dynamics: function (a) {
        var DAMPING = 0.95;       // fraction of velocity kept each step
        var ACCEL = 0.002;        // velocity change per unit of action
        var BOUNCE = -0.5;        // velocity factor applied on hitting a wall
        var DRIFT = 0.0005;       // per-step displacement of the penalty zone
        var lo = this.rad;        // lowest admissible centre coordinate
        var hi = 1 - this.rad;    // highest admissible centre coordinate (arena is 1 wide)

        // move with the current velocity, then damp it and add the action's push
        this.ppx += this.pvx;
        this.ppy += this.pvy;
        this.pvx = this.pvx * DAMPING + ACCEL * a[0];
        this.pvy = this.pvy * DAMPING + ACCEL * a[1];

        // walls: clamp the centre back inside and reverse/halve the velocity
        if (this.ppx < lo) {
          this.ppx = lo;
          this.pvx *= BOUNCE;
        }
        if (this.ppx > hi) {
          this.ppx = hi;
          this.pvx *= BOUNCE;
        }
        if (this.ppy < lo) {
          this.ppy = lo;
          this.pvy *= BOUNCE;
        }
        if (this.ppy > hi) {
          this.ppy = hi;
          this.pvy *= BOUNCE;
        }

        // every 100 steps the target jumps to a new random location
        this.t += 1;
        if (this.t % 100 === 0) {
          this.tx = Math.random();
          this.ty = Math.random();
        }

        // distance puck -> target
        var gapX = this.ppx - this.tx;
        var gapY = this.ppy - this.ty;
        var distTarget = Math.sqrt(gapX * gapX + gapY * gapY);

        // distance puck -> penalty-zone centre (tiny offset avoids division by zero)
        var offX = this.ppx - this.tx2;
        var offY = this.ppy - this.ty2;
        var distBad = Math.sqrt(offX * offX + offY * offY) + 1e-11;

        // shift the penalty zone by DRIFT along the unit vector pointing from
        // its centre to the puck
        this.tx2 += DRIFT * (offX / distBad);
        this.ty2 += DRIFT * (offY / distBad);

        // reward: farther from the target is worse ...
        var reward = -distTarget;
        // ... and being inside the penalty zone subtracts up to 2 more
        if (distBad < this.BADRAD) {
          reward += 2 * (distBad - this.BADRAD) / this.BADRAD;
        }

        return reward;
      }
    }

    /////////////////////////////////////////////////////////

    // Speed presets for the "Go fast / Go normal / Go slow" buttons. Each one
    // sets the global steps_per_tick, i.e. how many simulation steps the run
    // loop executes per timer tick.
    function setStepsPerTick(count) {
      steps_per_tick = count;
    }
    function gofast() {
      setStepsPerTick(100);
    }
    function gonormal() {
      setStepsPerTick(10);
    }
    function goslow() {
      setStepsPerTick(1);
    }

    // flot stuff: reward curve plotting
    // Maximum number of points kept in / shown on the reward plot.
    var nflot = 1000;

    /**
     * Creates the reward plot inside #flotreward and starts a 100 ms timer that
     * re-reads the smoothed reward history and redraws the plot.
     *
     * Intended to be called once; every call starts another refresh timer.
     */
    function initFlot() {
      var container = $("#flotreward");
      // `series` is local: the original assigned it without `var`, leaking a
      // global. The refresh timer below keeps it alive through its closure.
      var series = [{
        data: getFlotRewards(),
        lines: { fill: true }
      }];
      var plot = $.plot(container, series, {
        grid: {
          borderWidth: 1,
          minBorderMargin: 20,
          labelMargin: 10,
          backgroundColor: {
            colors: ["#FFF", "#e4f4f4"]
          },
          margin: {
            top: 10,
            bottom: 10,
            left: 10
          }
        },
        xaxis: {
          min: 0,
          max: nflot
        },
        yaxis: {
          min: -2,
          max: 1
        }
      });
      setInterval(function () {
        series[0].data = getFlotRewards();
        plot.setData(series);
        plot.draw();
      }, 100);
    }
    /**
     * Converts the global smooth_reward_history into flot's data format:
     * an array of [index, value] pairs, one per recorded sample.
     */
    function getFlotRewards() {
      return smooth_reward_history.map(function (value, index) {
        return [index, value];
      });
    }

    var steps_per_tick = 1; // simulation steps executed per timer tick (run speed)
    var sid = -1;           // id of the running interval timer, or -1 while paused
    var action, envState;
    var smooth_reward_history = []; // samples of smooth_reward shown on the plot
    var smooth_reward = null;       // exponentially smoothed reward; null until the first step
    var flott = 0;                  // steps since the last sample was recorded

    /**
     * Starts the run loop if it is paused, or pauses it if it is running.
     *
     * While running, a 20 ms timer executes `steps_per_tick` agent/environment
     * steps, updates the smoothed reward (recording a sample every 200 steps,
     * keeping at most `nflot` samples), redraws the scene and refreshes the
     * diagnostics shown in #expi, #tde and #qsa.
     *
     * The agent methods used here (observe, performPolicy, performAction, learn)
     * and the fields reward1 / expi / tderror / qsa come from lib/rl.js.
     */
    function togglelearn() {
      if (sid === -1) {
        sid = setInterval(function () {
          var action, reward;

          for (var k = 0; k < steps_per_tick; k++) {
            envState = env.getState();
            agent.state = agent.observe();
            action = agent.performPolicy(agent.state);
            agent.performAction(action);
            reward = agent.reward1;
            agent.learn();
            if (smooth_reward == null) { smooth_reward = reward; }
            smooth_reward = smooth_reward * 0.999 + reward * 0.001;
            flott += 1;
            if (flott === 200) {
              // record smooth reward, dropping the oldest sample when full
              if (smooth_reward_history.length >= nflot) {
                smooth_reward_history = smooth_reward_history.slice(1);
              }
              smooth_reward_history.push(smooth_reward);
              flott = 0;
            }
          }

          // Only redraw when at least one step ran; otherwise there is no action
          // to draw and updateDraw() would fail indexing an undefined value.
          if (typeof action !== 'undefined') {
            updateDraw(action, envState, reward);
          }
          if (typeof agent.expi !== 'undefined') {
            $("#expi").html(agent.expi);
          }
          if (typeof agent.tderror !== 'undefined') {
            $("#tde").html(agent.tderror.toFixed(3));
          }
          // Guard on agent.qsa itself. The original tested agent.tderror here, so
          // an agent exposing tderror but not qsa threw on every tick.
          if (typeof agent.qsa !== 'undefined') {
            $("#qsa").html(agent.qsa.toFixed(3));
          }
        }, 20);
      } else {
        clearInterval(sid);
        sid = -1;
      }
    }

    /**
     * Reveals the #mysterybox textarea and fills it with the current agent
     * serialised as JSON (via agent.toJSON()).
     */
    function saveAgent() {
      var box = $("#mysterybox");
      box.fadeIn();
      var serialized = JSON.stringify(agent.toJSON());
      box.val(serialized);
    }

    /**
     * Fetches agentzoo/puckagent.json and loads it into the current agent.
     *
     * After loading, exploration is lowered (epsilon = 0.05, mirrored in the
     * slider and the #eps label) and the learning rate is set to 0 so the
     * pretrained agent is only evaluated, not trained further.
     */
    function loadAgent() {
      $.getJSON("agentzoo/puckagent.json", function (data) {
        agent.fromJSON(data); // cross your fingers...
        // set epsilon to be much lower for more optimal behavior
        agent.epsilon = 0.05;
        $("#slider").slider('value', agent.epsilon);
        $("#eps").html(agent.epsilon.toFixed(2));
        // kill learning rate to not learn
        agent.alpha = 0;
        // Was `Console.log`, which is undefined and threw a ReferenceError at the
        // end of every successful load.
        console.log("agent loaded");
      }).fail(function (jqXHR, textStatus, errorThrown) {
        // Surface download/parse failures instead of failing silently.
        console.error("failed to load pretrained agent: " + textStatus, errorThrown);
      });
    }

    /**
     * Replaces the global agent with a fresh one built from the spec currently
     * typed into the #agentspec textarea.
     *
     * NOTE(review): the textarea content is executed with eval(); it is expected
     * to declare `var spec`, which the direct eval places in this function's scope.
     */
    function resetAgent() {
      var specSource = $("#agentspec").val();
      eval(specSource);
      agent = new RL.GaussianActorCriticAgent(env, spec);
    }

    // Globals retained from the original page; nothing in this file reads them.
    var w; // global world object
    var current_interval_id;
    var skipdraw = false;

    /**
     * Page entry point (invoked from the body's onload handler).
     *
     * Creates the environment and the scene, builds the agent from the spec in
     * #agentspec, sets up the reward plot and the epsilon slider, starts the run
     * loop, then renders any markdown blocks and loads MathJax.
     *
     * `env` and `agent` are deliberately assigned as globals: the other
     * functions in this file read them.
     *
     * NOTE(review): the spec text is executed with eval(); it is expected to
     * declare `var spec`.
     */
    function start() {
      env = new PuckWorld();
      initDraw();

      var specSource = $("#agentspec").val();
      eval(specSource);

      // Alternative agents tried previously: RL.ActorCritic, RL.DQNAgent,
      // RL.RecurrentReinforceAgent.
      agent = new RL.GaussianActorCriticAgent(env, spec);
      initFlot();

      // The slider drives the agent's epsilon and mirrors it in the #eps label.
      var onSlide = function (event, ui) {
        agent.epsilon = ui.value;
        $("#eps").html(ui.value.toFixed(2));
      };
      var sliderOptions = {
        min: 0,
        max: 1,
        value: agent.epsilon,
        step: 0.01,
        slide: onSlide
      };
      $("#slider").slider(sliderOptions);
      $("#eps").html(agent.epsilon.toFixed(2));

      // begin running immediately
      togglelearn();

      // Convert the content of every .md element from markdown to HTML.
      $(".md").each(function () {
        var block = $(this);
        block.html(marked(block.html()));
      });
      renderJax();
    }

    // Set once the MathJax script element has been injected.
    var jaxrendered = false;

    /**
     * Loads MathJax by appending a script element to the document head.
     * Does nothing on second and later calls.
     *
     * The script is requested over HTTPS: the original used a plain http:// URL,
     * which browsers block as mixed content when the page itself is served over
     * HTTPS.
     */
    function renderJax() {
      if (jaxrendered) { return; }
      var script = document.createElement("script");
      script.type = "text/javascript";
      script.src = "https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML";
      document.getElementsByTagName("head")[0].appendChild(script);
      jaxrendered = true;
    }
  </script>
  <style type="text/css">
    /* Frame for canvas elements. NOTE(review): the static markup of this page
       contains no canvas (the arena is an svg); this rule may be left over from
       another demo page - confirm before removing. */
    canvas {
      border: 1px solid white;
    }
  </style>

</head>

<body onload="start();">

  <a href="https://github.com/karpathy/reinforcejs"><img style="position: absolute; top: 0; right: 0; border: 0;" src="https://camo.githubusercontent.com/38ef81f8aca64bb9a64448d0d70f1308ef5341ab/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f6461726b626c75655f3132313632312e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_darkblue_121621.png"></a>

  <!-- Fixed-width page column; closed just before the end of the body. -->
  <div id="wrap">

    <div id="mynav" style="border-bottom:1px solid #999; padding-bottom: 10px; margin-bottom:50px;">
      <div>
        <!-- Decorative logo: empty alt so screen readers skip it. -->
        <img src="loop.svg" alt="" style="width:50px;height:50px;float:left;">
        <h1 style="font-size:50px;">REINFORCE<span style="color:#058;">js</span></h1>
      </div>
      <ul class="nav nav-pills">
        <li role="presentation"><a href="index.html">About</a></li>
        <li role="presentation"><a href="gridworld_dp.html">GridWorld: DP</a></li>
        <li role="presentation"><a href="gridworld_td.html">GridWorld: TD</a></li>
        <li role="presentation" class="active"><a href="puckworld.html">PuckWorld: DQN</a></li>
        <li role="presentation"><a href="waterworld.html">WaterWorld: DQN</a></li>
      </ul>
    </div>

    <h2>PuckWorld: Deep Q Learning</h2>

    <!-- The text below is executed with eval() by start() and resetAgent(). -->
    <textarea id="agentspec" style="width:100%;height:230px;">
// agent parameter spec to play with (this gets eval()'d on Agent reset)
var spec = {}
spec.update = 'qlearn'; // qlearn | sarsa
spec.gamma = 0.9; // discount factor, [0, 1)
spec.epsilon = 0.2; // initial epsilon for epsilon-greedy policy, [0, 1)
spec.alpha = 0.01; // value function learning rate
spec.experience_add_every = 10; // number of time steps before we add another experience to replay memory
spec.experience_size = 5000; // size of experience replay memory
spec.learning_steps_per_iteration = 20;
spec.tderror_clamp = 1.0; // for robustness
spec.num_hidden_units = 100 // number of neurons in hidden layer
</textarea>

    <button type="button" class="btn btn-danger" onclick="resetAgent()" style="width:150px;height:50px;margin-bottom:5px;">Reinit agent</button>
    <button type="button" class="btn btn-primary" onclick="togglelearn()" style="width:150px;height:50px;margin-bottom:5px;">Toggle Run</button>
    <button type="button" class="btn btn-success" onclick="gofast()" style="width:150px;height:50px;margin-bottom:5px;">Go fast</button>
    <button type="button" class="btn btn-success" onclick="gonormal()" style="width:150px;height:50px;margin-bottom:5px;">Go normal</button>
    <button type="button" class="btn btn-success" onclick="goslow()" style="width:150px;height:50px;margin-bottom:5px;">Go slow</button>

    <br> Exploration epsilon: <span id="eps">0.15</span>
    <div id="slider"></div>
    <br>

    <!-- initDraw() builds the SVG arena inside this container. -->
    <div id="draw"></div>


    <br>
    <div style="float:right;">
      <button type="button" class="btn btn-primary" onclick="loadAgent()" style="width:200px;height:35px;margin-bottom:5px;margin-right:20px;">Load a Pretrained Agent</button>
    </div>
    <textarea id="mysterybox" style="width:100%;display:none;">mystery text box</textarea>

    <div> Experience write pointer:
      <div id="expi" style="display:inline-block;"></div>
    </div>
    <div> Latest TD error:
      <div id="tde" style="display:inline-block;"></div>
    </div>
    <div> Latest Q(s,a):
      <div id="qsa" style="display:inline-block;"></div>
    </div>
    <div id="tdest"></div>

    <div>
      <div style="text-align:center">平均奖励 (high is good)</div>
      <div id="flotreward" style="width:800px; height: 400px;"></div>
    </div>

    <br><br><br><br><br>
  </div>
</body>

</html>